%load_ext autoreload
%autoreload 2

Training Kalman Filter for Results - 21 Feb
from meteo_imp.kalman.training import *
from meteo_imp.kalman.filter import *
from meteo_imp.utils import *
from meteo_imp.data import *
from meteo_imp.gaussian import *
from fastai.tabular.learner import *
from fastai.learner import *
from fastai.callback.all import *
from fastcore.foundation import L
from fastcore.foundation import patch
from pathlib import Path, PosixPath
from meteo_imp.kalman.fastai import show_results
import pandas as pd
import numpy as np
import torch
import random
from pyprojroot import here
from sklearn.decomposition import PCA
from matplotlib import pyplot as plt
from IPython.display import Image, HTML
# Remaining imports plus notebook-level setup (untangled from fused export lines).
from tqdm.auto import tqdm
from fastcore.basics import *

show_metrics = False  # when True, compute validation metrics before each fine-tune (slow)
reset_seed()          # fix RNG seeds for reproducibility (meteo_imp.utils)

# Hainich half-hourly tower data (gap-filling target).
hai = pd.read_parquet(hai_big_path)
# ERA5 reanalysis series used as control input for the Kalman filter.
hai_era = pd.read_parquet(hai_era_big_path)

# Directory where trained models, CSV logs, and loss plots are cached.
base = here("analysis/results/trained_models")
base.mkdir(exist_ok=True)

@patch
def add_end(self: PosixPath, end):
    """Return this path with `end` appended to the file name (e.g. ``p.add_end(".pickle")``)."""
    return self.parent / (self.name + end)

def train_or_load(model, dls, lr, n, path, keep=True):
    """Train `model` on `dls`, or load a previously trained copy cached at `path`.

    NOTE(review): the parameter names are swapped relative to their use —
    ``learn.fit(lr, n)`` calls fastai's ``Learner.fit(n_epoch, lr)``, so the
    argument named ``lr`` actually receives the NUMBER OF EPOCHS and ``n`` the
    LEARNING RATE.  Every call site in this notebook follows that convention
    (e.g. ``train_or_load(model, dls, 3, 1e-3, …)`` = 3 epochs at lr 1e-3);
    the names are kept unchanged for backward compatibility.

    If ``path.add_end(".pickle")`` exists and `keep` is True, the cached model
    is loaded and the logged history + loss plot are re-displayed instead of
    retraining.  Returns ``(learner, items)`` where `items` are 4 randomly
    chosen validation items (for later result plots).
    """
    save_models = SaveModelsBatch(times_epoch=5)
    csv_logger = CSVLogger(path.add_end("log.csv"))
    learn = Learner(dls, model, KalmanLoss(only_gap=True),
                    cbs=[Float64Callback, save_models, csv_logger], metrics=rmse_gap)
    items = random.choices(dls.valid.items, k=4)
    if path.add_end(".pickle").exists() and keep:
        # Cached run: load weights and replay the saved log and loss plot.
        learn.model = torch.load(path.add_end(".pickle"))
        display(csv_logger.read_log())
        display(Image(filename=path.add_end("_loss_plot.png")))
    else:
        learn.fit(lr, n)  # fit(n_epoch=lr, lr=n) — see docstring note above
        torch.save(learn.model, path.add_end(".pickle"))
        learn.recorder.plot_loss()
        plt.savefig(path.add_end("_loss_plot.png"))
    return learn, items

def metric_valid(learn, dls=None):
    """Mean/std of loss and gap-RMSE over a dataloader (default: ``learn.dls.valid``)."""
    nrmse = []
    losses = []
    dls = ifnone(dls, learn.dls.valid)
    for xb, yb in tqdm(dls, leave=False):
        pred = learn.model(xb)
        nrmse.append(learn.metrics[0](pred, yb))
        losses.append(learn.loss_func(pred, yb).item())
    metric = pd.DataFrame({'loss': losses, 'rmse': nrmse})
    return metric.agg(['mean', 'std'])

# Recorded: hai.columns == Index(['TA','SW_IN','LW_IN','VPD','WS','PA','P','SWC','TS'])
Generic model | gap len 6-336 | gap 1 random
# Dataloader for the generic model: ONE randomly chosen variable is gapped
# per block, gap length varying between 6 and 336 half-hourly steps.
dls_A1v = imp_dataloader(
    df=hai,
    control=hai_era,
    var_sel=gen_var_sel(list(hai.columns), n_var=1),  # one gap variable at a time
    block_len=100 + 336,         # context (100) + longest possible gap (336)
    gap_len=gen_gap_len(6, 336),
    bs=20,
    control_lags=[1],
    shifts=gen_shifts(50),
    n_rep=5).cpu()
# Recorded outputs: len(hai) == 227952;
# len(dls_A1v.train)*20, len(dls_A1v.valid)*20 == (2080, 520)
# Generic Kalman filter, local-slope + PCA initialisation, predicting only at gaps.
model_A1v = KalmanFilterSR.init_local_slope_pca(
    n_dim_obs=len(hai.columns),
    n_dim_state=len(hai.columns),
    n_dim_contr=len(hai_era.columns),
    df_pca=None,
    pred_only_gap=True)
# Recorded: model_A1v.B.shape == torch.Size([1, 18, 14])
# (dims are doubled internally by the local-slope init — TODO confirm)

# 3 epochs at lr 1e-3 (see train_or_load argument convention).
learn_A1v, items_A1v = train_or_load(model_A1v, dls_A1v, 3, 1e-3, base / "1_gap_varying_6-336_v1")
|---|---|---|---|---|---|
| 0 | 0 | 196.765350 | 163.493486 | 0.579074 | 48:33 |
| 1 | 1 | 138.298704 | 123.299909 | 0.490741 | 48:14 |
| 2 | 2 | 113.640141 | 116.746793 | 0.488059 | 39:00 |

# Continue training the generic model: 1 epoch at lr 1e-4.
learn_A1v, items_A1v = train_or_load(model_A1v, dls_A1v, 1, 1e-4, base / "1_gap_varying_6-336_v2")
|---|---|---|---|---|---|
| 0 | 0 | 94.869328 | 112.046392 | 0.471249 | 43:59 |

# Final polish of the generic model: 1 epoch at lr 1e-6.
learn_A1v, items_A1v = train_or_load(model_A1v, dls_A1v, 1, 1e-6, base / "1_gap_varying_6-336_v3")
|---|---|---|---|---|---|
| 0 | 0 | 95.239438 | 104.268073 | 0.467282 | 40:19 |

1 var gap - varying 336 - No Control
# Ablation: same generic model but with the ERA5 control input disabled.
model_A1v_nc = KalmanFilterSR.init_local_slope_pca(
    n_dim_obs=len(hai.columns),
    n_dim_state=len(hai.columns),
    n_dim_contr=len(hai_era.columns),
    df_pca=None,
    pred_only_gap=True,
    use_control=False,
)
# 3 epochs at lr 1e-3.
learn_A1v_nc, items_A1v_nc = train_or_load(
    model_A1v_nc, dls_A1v, 3, 1e-3, base / "1_gap_varying_336_no_control_v1")
|---|---|---|---|---|---|
| 0 | 0 | 219.686355 | 178.506325 | 0.658579 | 37:14 |
| 1 | 1 | 176.039201 | 160.979378 | 0.583213 | 37:00 |
| 2 | 2 | 166.012525 | 158.206468 | 0.574111 | 36:47 |

Short gaps
All variables - 30 all
# Short gaps (6-30 steps), ALL variables gapped simultaneously.
dls_Aa = imp_dataloader(
    df=hai,
    control=hai_era,
    var_sel=list(hai.columns),   # every variable is missing in the gap
    block_len=120,
    gap_len=gen_gap_len(6, 30),
    bs=20,
    control_lags=[1],
    shifts=gen_shifts(50),
    n_rep=5
).cpu()
# NOTE(review): dls_Aa is immediately rebuilt here (no shifts, n_rep=10),
# silently discarding the dataloader built just above — presumably a
# notebook-iteration leftover; verify which configuration is intended.
dls_Aa = imp_dataloader(hai, hai_era, var_sel=list(hai.columns), block_len=120,
                        gap_len=gen_gap_len(6, 30), bs=20, control_lags=[1], n_rep=10).cpu()

# Warm-start the all-variables model from the trained generic 1-var model.
model_Aa = learn_A1v.model.copy()
if show_metrics: display(metric_valid(learn_A1v, dls=dls_Aa.valid))

# One variable gapped, short gaps — for comparison against the tuned model.
dls_A1v30 = imp_dataloader(
    df=hai,
    control=hai_era,
    var_sel=gen_var_sel(list(hai.columns), n_var=1),
    block_len=100 + 30,
    gap_len=gen_gap_len(6, 30),
    bs=20,
    control_lags=[1],
    shifts=gen_shifts(50),
    n_rep=5).cpu()
if show_metrics: display(metric_valid(learn_A1v, dls=dls_A1v30.valid))

# 3 epochs at lr 3e-4.
learn_Aa, items_Aa = train_or_load(model_Aa, dls_Aa, 3, 3e-4, base / "All_gap_all_30_v1")
|---|---|---|---|---|---|
| 0 | 0 | 24.688308 | 21.049544 | 0.365108 | 28:05 |
| 1 | 1 | -7.512621 | -4.152000 | 0.342344 | 27:26 |
| 2 | 2 | -18.230698 | -19.744404 | 0.327594 | 26:15 |

# Fine-tune a copy of the generic model on short (6-30) gaps: 3 epochs at lr 3e-4.
learn_A1v30, items_A1v30 = train_or_load(
    learn_A1v.model.copy(), dls_A1v30, 3, 3e-4, base / "1_gap_varying_tuned_6-30_v1")
|---|---|---|---|---|---|
| 0 | 0 | 4.763339 | 3.134267 | 0.276432 | 15:02 |
| 1 | 1 | 2.390429 | 1.721823 | 0.267772 | 16:12 |
| 2 | 2 | 0.780169 | 0.607745 | 0.255524 | 16:05 |

so this is not working …
Varying number of variables missing | short gaps 6-30
# Varying number of variables missing per gap, short gaps 6-30.
dls_Vv30 = imp_dataloader(
    df=hai,
    control=hai_era,
    var_sel=gen_var_sel(list(hai.columns)),  # random subset of variables gapped
    block_len=100 + 30,
    gap_len=gen_gap_len(6, 30),
    bs=20,
    control_lags=[1],
    shifts=gen_shifts(50),
    n_rep=20).cpu()
# 3 epochs at lr 5e-4, starting from the all-variables model.
learn_Vv30, items_Vv30 = train_or_load(
    learn_Aa.model, dls_Vv30, 3, 5e-4, base / "all_varying_gap_varying_len_6-30_v1")
|---|---|---|---|---|---|
| 0 | 0 | -5.794561 | -4.508800 | 0.213908 | 1:00:55 |
| 1 | 1 | -3.717697 | -5.165062 | 0.205841 | 1:00:26 |
| 2 | 2 | -1.928287 | -6.012112 | 0.202048 | 1:00:21 |

# Continue: 1 epoch at lr 1e-5.
learn_Vv30, items_Vv30 = train_or_load(
    learn_Aa.model, dls_Vv30, 1, 1e-5, base / "all_varying_gap_varying_len_6-30_v2")
|---|---|---|---|---|---|
| 0 | 0 | -7.061014 | -6.663726 | 0.192194 | 57:55 |

# One more epoch at lr 1e-5 (v3 checkpoint).
learn_Vv30, items_Vv30 = train_or_load(
    learn_Aa.model, dls_Vv30, 1, 1e-5, base / "all_varying_gap_varying_len_6-30_v3")
|---|---|---|---|---|---|
| 0 | 0 | -7.417934 | -6.799482 | 0.19076 | 1:00:23 |

Random parameters
# Baseline: randomly initialised parameters (seeded), doubled state/control dims.
model_Vv_rand = KalmanFilterSR.init_random(
    n_dim_obs=len(hai.columns),
    n_dim_state=2 * len(hai.columns),
    n_dim_contr=2 * len(hai_era.columns),
    seed=27,
    pred_only_gap=True)
# 3 epochs at lr 1e-3.
learn_Vv_rand, items_Vv_rand = train_or_load(
    model_Vv_rand, dls_Vv30, 3, 1e-3, base / "rand_all_varying_gap_varying_len_6-30_v1")
|---|---|---|---|---|
| 0 | 13.508053 | 14.374478 | 0.448185 | 1:00:35 |
| 1 | 9.766153 | 11.389963 | 0.395332 | 58:51 |
| 2 | 6.503961 | 6.754238 | 0.305433 | 54:16 |

# Continue random-init baseline: 3 epochs at lr 1e-4.
learn_Vv_rand, items_Vv_rand = train_or_load(
    learn_Vv_rand.model, dls_Vv30, 3, 1e-4, base / "rand_all_varying_gap_varying_len_6-30_v2")
|---|---|---|---|---|
| 0 | 5.101230 | 5.345336 | 0.285398 | 1:02:15 |
| 1 | 4.838514 | 4.907970 | 0.281667 | 1:02:44 |
| 2 | 4.571287 | 4.295109 | 0.275344 | 1:04:47 |

# Continue: 3 epochs at lr 1e-5 (v3 checkpoint).
learn_Vv_rand, items_Vv_rand = train_or_load(
    learn_Vv_rand.model, dls_Vv30, 3, 1e-5, base / "rand_all_varying_gap_varying_len_6-30_v3")
|---|---|---|---|---|
| 0 | 4.075190 | 4.282635 | 0.275272 | 2:36:25 |
| 1 | 5.134852 | 4.244826 | 0.274858 | 1:17:58 |
| 2 | 3.923739 | 4.185355 | 0.274057 | 1:11:31 |

# Final round: 3 more epochs at lr 1e-5 (v4 checkpoint).
learn_Vv_rand, items_Vv_rand = train_or_load(
    learn_Vv_rand.model, dls_Vv30, 3, 1e-5, base / "rand_all_varying_gap_varying_len_6-30_v4")
|---|---|---|---|---|
| 0 | 4.100049 | 4.142156 | 0.273928 | 56:48 |
| 1 | 3.712223 | 4.112651 | 0.273566 | 56:59 |
| 2 | 3.942678 | 4.073334 | 0.273349 | 57:03 |

model_Vv_randKalman Filter (9 obs, 18 state, 14 contr)
$A$
| state | x_0 | x_1 | x_2 | x_3 | x_4 | x_5 | x_6 | x_7 | x_8 | x_9 | x_10 | x_11 | x_12 | x_13 | x_14 | x_15 | x_16 | x_17 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| x_0 | 0.8775 | 0.2675 | 0.0937 | 0.6706 | 0.1638 | 0.9272 | 0.2620 | 0.4967 | 0.2630 | 0.1175 | 0.1694 | 0.2100 | 0.4890 | 0.0564 | 0.4760 | 0.7606 | 0.7759 | 0.5243 |
| x_1 | 0.3714 | 0.0426 | 0.2343 | 0.9991 | 0.1775 | 0.6319 | 0.6734 | 0.7937 | 0.6468 | 0.5825 | 0.4599 | 0.7960 | 0.9038 | 0.9735 | 0.6428 | 0.3725 | 0.2052 | 0.0507 |
| x_2 | 0.4448 | 0.5775 | 0.7237 | 0.5927 | 0.3217 | 0.6441 | 0.2801 | 0.9132 | 0.0329 | 0.4856 | 0.9927 | 0.5895 | 0.2611 | 0.9413 | 0.1371 | 0.8726 | 0.5590 | 0.8451 |
| x_3 | 0.1253 | 0.9434 | 0.0462 | 0.2360 | 0.0239 | 0.8950 | 0.7419 | 0.9471 | 0.6690 | 0.1554 | 0.0821 | 0.7309 | 0.7764 | 0.9769 | 0.0196 | 0.0384 | 0.4294 | 0.3438 |
| x_4 | 0.5494 | 0.8238 | 0.9845 | 0.6826 | 0.9001 | 0.3022 | 0.7509 | 0.0926 | 0.0328 | 0.4798 | 0.5335 | 0.0434 | 0.3530 | 0.4157 | 0.7495 | 0.1716 | 0.1980 | 0.4298 |
| x_5 | 0.9201 | 0.6883 | 0.5342 | 0.7847 | 0.3137 | 0.1778 | 0.5838 | 0.9799 | 0.3611 | 0.3155 | 0.7475 | 0.5450 | 0.5641 | 0.2493 | 0.8323 | 0.9723 | 0.1883 | 0.3605 |
| x_6 | 0.5344 | 0.3443 | 0.7696 | 0.3410 | 0.7553 | 0.3177 | 0.0315 | 0.5209 | 0.6514 | 0.3131 | 0.4510 | 0.3550 | 0.4790 | 0.0676 | 0.3606 | 0.7299 | 0.6713 | 0.3134 |
| x_7 | 0.7460 | 0.1291 | 0.4653 | 0.5693 | 0.9906 | 0.8288 | 0.9039 | 0.5240 | 0.6277 | 0.3574 | 0.0076 | 0.6530 | 0.8667 | 0.9368 | 0.8667 | 0.6749 | 0.3526 | 0.6618 |
| x_8 | 0.0837 | 0.7188 | 0.7247 | 0.3211 | 0.4898 | 0.9030 | 0.0358 | 0.1662 | 0.7741 | 0.7937 | 0.7183 | 0.5141 | 0.4918 | 0.2773 | 0.6901 | 0.8565 | 0.3723 | 0.3410 |
| x_9 | 0.4035 | 0.0591 | 0.6836 | 0.8306 | 0.4312 | 0.0210 | 0.0032 | 0.9010 | 0.6741 | 0.3875 | 0.3683 | 0.5337 | 0.0706 | 0.8516 | 0.7304 | 0.8507 | 0.6829 | 0.6900 |
| x_10 | 0.1059 | 0.0500 | 0.5736 | 0.9595 | 0.8101 | 0.7397 | 0.5282 | 0.1294 | 0.2746 | 0.5556 | 0.6463 | 0.0023 | 0.1761 | 0.3391 | 0.3346 | 0.4655 | 0.8172 | 0.4176 |
| x_11 | 0.1349 | 0.0519 | 0.1180 | 0.9767 | 0.1679 | 0.8635 | 0.3753 | 0.9760 | 0.2125 | 0.8049 | 0.2124 | 0.6794 | 0.0037 | 0.9711 | 0.5679 | 0.9474 | 0.8593 | 0.6385 |
| x_12 | 0.8770 | 0.0469 | 0.1582 | 0.6694 | 0.5670 | 0.9794 | 0.6498 | 0.3257 | 0.8462 | 0.7727 | 0.3213 | 0.7318 | 0.3665 | 0.9550 | 0.7188 | 0.2660 | 0.5867 | 0.1134 |
| x_13 | 0.7401 | 0.1982 | 0.4165 | 0.3814 | 0.5263 | 0.6516 | 0.9604 | 0.8996 | 0.8318 | 0.7448 | 0.6912 | 0.5938 | 0.0929 | 0.5298 | 0.2637 | 0.8722 | 0.5430 | 0.2217 |
| x_14 | 0.3495 | 0.3756 | 0.1251 | 0.4052 | 0.0638 | 0.0588 | 0.4379 | 0.4891 | 0.2796 | 0.0740 | 0.2123 | 0.1370 | 0.4477 | 0.3628 | 0.9125 | 0.4047 | 0.8130 | 0.2332 |
| x_15 | 0.8424 | 0.0816 | 0.8791 | 0.3892 | 0.2923 | 0.8603 | 0.1172 | 0.6212 | 0.6087 | 0.6072 | 0.8778 | 0.6758 | 0.5495 | 0.8240 | 0.7461 | 0.1555 | 0.2950 | 0.0365 |
| x_16 | 0.8060 | 0.8602 | 0.9453 | 0.7811 | 0.5495 | 0.5861 | 0.8480 | 0.1940 | 0.9206 | 0.5589 | 0.2148 | 0.1828 | 0.0636 | 0.2885 | 0.9426 | 0.6787 | 0.0080 | 0.7527 |
| x_17 | 0.5032 | 0.5585 | 0.0789 | 0.0409 | 0.3918 | 0.2908 | 0.3802 | 0.0407 | 0.6447 | 0.3241 | 0.8544 | 0.4245 | 0.3987 | 0.4367 | 0.3384 | 0.2285 | 0.7890 | 0.9094 |
$Q$
| state | x_0 | x_1 | x_2 | x_3 | x_4 | x_5 | x_6 | x_7 | x_8 | x_9 | x_10 | x_11 | x_12 | x_13 | x_14 | x_15 | x_16 | x_17 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| x_0 | 1.4985 | 0.1823 | 0.8819 | 1.1215 | 0.8124 | 1.0154 | 0.4127 | 0.4993 | 0.3792 | 1.1249 | 0.3349 | 0.5480 | 1.1208 | 1.0406 | 0.0373 | 0.0713 | 0.0563 | 0.0038 |
| x_1 | 0.1823 | 0.7474 | 0.3388 | 0.8779 | 0.3284 | 0.4377 | 0.5683 | 0.6806 | 0.1517 | 0.6860 | 0.1334 | 0.2869 | 0.3903 | 0.5594 | 0.2687 | 0.7078 | 0.5301 | 0.5990 |
| x_2 | 0.8819 | 0.3388 | 1.5250 | 1.3832 | 1.4418 | 0.7184 | 0.4702 | 1.2013 | 1.1584 | 1.3029 | 1.0110 | 1.0116 | 1.4669 | 1.5483 | 0.6885 | 0.7897 | 1.0060 | 0.8858 |
| x_3 | 1.1215 | 0.8779 | 1.3832 | 2.4360 | 1.6422 | 1.2901 | 1.1545 | 1.7087 | 1.4626 | 2.4015 | 1.3411 | 1.4348 | 2.0194 | 2.1839 | 1.2870 | 1.3528 | 1.5829 | 1.6857 |
| x_4 | 0.8124 | 0.3284 | 1.4418 | 1.6422 | 3.2440 | 1.8293 | 0.6635 | 2.2540 | 1.7788 | 2.2643 | 1.5406 | 1.9412 | 2.5939 | 2.3073 | 2.1935 | 1.1262 | 1.5039 | 1.7599 |
| x_5 | 1.0154 | 0.4377 | 0.7184 | 1.2901 | 1.8293 | 2.6367 | 0.6717 | 2.0377 | 1.2394 | 1.6564 | 1.4530 | 1.5842 | 2.3638 | 2.1887 | 2.0828 | 1.2967 | 0.6836 | 0.8444 |
| x_6 | 0.4127 | 0.5683 | 0.4702 | 1.1545 | 0.6635 | 0.6717 | 1.8814 | 1.4915 | 0.6485 | 1.3643 | 0.6293 | 1.5192 | 1.1450 | 1.7599 | 1.6733 | 1.4465 | 0.7419 | 1.1889 |
| x_7 | 0.4993 | 0.6806 | 1.2013 | 1.7087 | 2.2540 | 2.0377 | 1.4915 | 3.7419 | 2.0936 | 2.3554 | 2.8167 | 2.9808 | 2.6624 | 3.4940 | 3.4288 | 2.7951 | 1.6056 | 2.6631 |
| x_8 | 0.3792 | 0.1517 | 1.1584 | 1.4626 | 1.7788 | 1.2394 | 0.6485 | 2.0936 | 3.2420 | 2.6115 | 2.9707 | 2.0584 | 3.2186 | 2.4394 | 2.7694 | 2.0308 | 2.6026 | 2.2404 |
| x_9 | 1.1249 | 0.6860 | 1.3029 | 2.4015 | 2.2643 | 1.6564 | 1.3643 | 2.3554 | 2.6115 | 4.7867 | 3.5267 | 2.4691 | 3.7434 | 3.6431 | 3.0501 | 2.1870 | 3.3714 | 2.4697 |
| x_10 | 0.3349 | 0.1334 | 1.0110 | 1.3411 | 1.5406 | 1.4530 | 0.6293 | 2.8167 | 2.9707 | 3.5267 | 4.8661 | 2.7361 | 3.9109 | 3.7590 | 4.1276 | 3.2131 | 3.3042 | 2.8676 |
| x_11 | 0.5480 | 0.2869 | 1.0116 | 1.4348 | 1.9412 | 1.5842 | 1.5192 | 2.9808 | 2.0584 | 2.4691 | 2.7361 | 4.5057 | 3.4122 | 4.2709 | 3.9642 | 3.3446 | 2.1979 | 2.7754 |
| x_12 | 1.1208 | 0.3903 | 1.4669 | 2.0194 | 2.5939 | 2.3638 | 1.1450 | 2.6624 | 3.2186 | 3.7434 | 3.9109 | 3.4122 | 6.2732 | 5.0000 | 5.0583 | 4.1123 | 3.9812 | 3.5458 |
| x_13 | 1.0406 | 0.5594 | 1.5483 | 2.1839 | 2.3073 | 2.1887 | 1.7599 | 3.4940 | 2.4394 | 3.6431 | 3.7590 | 4.2709 | 5.0000 | 6.2369 | 5.2005 | 4.6903 | 3.3753 | 4.1217 |
| x_14 | 0.0373 | 0.2687 | 0.6885 | 1.2870 | 2.1935 | 2.0828 | 1.6733 | 3.4288 | 2.7694 | 3.0501 | 4.1276 | 3.9642 | 5.0583 | 5.2005 | 7.8159 | 5.5318 | 4.1568 | 5.4245 |
| x_15 | 0.0713 | 0.7078 | 0.7897 | 1.3528 | 1.1262 | 1.2967 | 1.4465 | 2.7951 | 2.0308 | 2.1870 | 3.2131 | 3.3446 | 4.1123 | 4.6903 | 5.5318 | 5.8308 | 3.6118 | 5.0658 |
| x_16 | 0.0563 | 0.5301 | 1.0060 | 1.5829 | 1.5039 | 0.6836 | 0.7419 | 1.6056 | 2.6026 | 3.3714 | 3.3042 | 2.1979 | 3.9812 | 3.3753 | 4.1568 | 3.6118 | 5.2990 | 4.3170 |
| x_17 | 0.0038 | 0.5990 | 0.8858 | 1.6857 | 1.7599 | 0.8444 | 1.1889 | 2.6631 | 2.2404 | 2.4697 | 2.8676 | 2.7754 | 3.5458 | 4.1217 | 5.4245 | 5.0658 | 4.3170 | 7.1195 |
$b$
| state | offset |
|---|---|
| x_0 | 0.5371 |
| x_1 | 0.6015 |
| x_2 | 0.3190 |
| x_3 | 0.9543 |
| x_4 | 0.5112 |
| x_5 | 0.0341 |
| x_6 | 0.9601 |
| x_7 | 0.1604 |
| x_8 | 0.4499 |
| x_9 | 0.8575 |
| x_10 | 0.2647 |
| x_11 | 0.4293 |
| x_12 | 0.9210 |
| x_13 | 0.5512 |
| x_14 | 0.0890 |
| x_15 | 0.4351 |
| x_16 | 0.3804 |
| x_17 | 0.4879 |
$H$
| variable | x_0 | x_1 | x_2 | x_3 | x_4 | x_5 | x_6 | x_7 | x_8 | x_9 | x_10 | x_11 | x_12 | x_13 | x_14 | x_15 | x_16 | x_17 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| y_0 | 0.5241 | 0.2182 | 0.7958 | 0.7816 | 0.3235 | 0.8518 | 0.4334 | 0.7567 | 0.5235 | 0.2247 | 0.2498 | 0.6324 | 0.0037 | 0.8468 | 0.7664 | 0.0362 | 0.2519 | 0.5872 |
| y_1 | 0.4556 | 0.2781 | 0.0315 | 0.3598 | 0.2876 | 0.8363 | 0.0685 | 0.5543 | 0.9194 | 0.3232 | 0.0243 | 0.2689 | 0.8404 | 0.9788 | 0.9912 | 0.0846 | 0.1129 | 0.0503 |
| y_2 | 0.8881 | 0.6638 | 0.5292 | 0.3452 | 0.4999 | 0.6894 | 0.7628 | 0.4233 | 0.4219 | 0.3110 | 0.1801 | 0.5059 | 0.2597 | 0.9244 | 0.6246 | 0.8295 | 0.5742 | 0.7359 |
| y_3 | 0.2917 | 0.2912 | 0.9906 | 0.3964 | 0.5851 | 0.0647 | 0.3191 | 0.0659 | 0.9295 | 0.0189 | 0.8553 | 0.6701 | 0.6306 | 0.6152 | 0.5295 | 0.9469 | 0.9927 | 0.7433 |
| y_4 | 0.5977 | 0.7385 | 0.9348 | 0.8533 | 0.6523 | 0.7823 | 0.7676 | 0.4763 | 0.6374 | 0.8520 | 0.4391 | 0.5353 | 0.9097 | 0.7429 | 0.2067 | 0.4188 | 0.0382 | 0.9770 |
| y_5 | 0.6669 | 0.7935 | 0.4501 | 0.6770 | 0.0361 | 0.3082 | 0.9436 | 0.8420 | 0.2966 | 0.6996 | 0.8092 | 0.0206 | 0.9509 | 0.0499 | 0.3504 | 0.8491 | 0.5674 | 0.8691 |
| y_6 | 0.4429 | 0.2004 | 0.3868 | 0.9650 | 0.0220 | 0.4891 | 0.0179 | 0.3229 | 0.1670 | 0.6188 | 0.6477 | 0.0439 | 0.3738 | 0.3988 | 0.6175 | 0.9562 | 0.6395 | 0.7886 |
| y_7 | 0.6403 | 0.2487 | 0.6137 | 0.2387 | 0.7919 | 0.1610 | 0.2259 | 0.9336 | 0.8569 | 0.6710 | 0.9067 | 0.1028 | 0.7898 | 0.3126 | 0.5972 | 0.3078 | 0.3259 | 0.5631 |
| y_8 | 0.5374 | 0.9159 | 0.0255 | 0.7863 | 0.0953 | 0.7248 | 0.3355 | 0.1565 | 0.2010 | 0.3647 | 0.3080 | 0.8794 | 0.2877 | 0.2028 | 0.8040 | 0.8565 | 0.2100 | 0.2746 |
$R$
| variable | y_0 | y_1 | y_2 | y_3 | y_4 | y_5 | y_6 | y_7 | y_8 |
|---|---|---|---|---|---|---|---|---|---|
| y_0 | 0.5106 | 0.3847 | 0.4957 | 0.2641 | 0.0725 | 0.3685 | 0.7145 | 0.0334 | 0.6538 |
| y_1 | 0.3847 | 1.1931 | 1.0307 | 0.4238 | 0.1569 | 0.4712 | 0.6801 | 0.9019 | 0.7750 |
| y_2 | 0.4957 | 1.0307 | 1.9865 | 0.5453 | 0.3404 | 0.8202 | 1.0589 | 0.7202 | 1.1494 |
| y_3 | 0.2641 | 0.4238 | 0.5453 | 1.3356 | 0.2462 | 0.9538 | 0.8949 | 1.0863 | 1.4202 |
| y_4 | 0.0725 | 0.1569 | 0.3404 | 0.2462 | 0.9155 | 0.8953 | 0.5091 | 0.3793 | 0.3845 |
| y_5 | 0.3685 | 0.4712 | 0.8202 | 0.9538 | 0.8953 | 1.9467 | 1.7410 | 1.1834 | 1.4452 |
| y_6 | 0.7145 | 0.6801 | 1.0589 | 0.8949 | 0.5091 | 1.7410 | 2.9640 | 1.6060 | 1.6805 |
| y_7 | 0.0334 | 0.9019 | 0.7202 | 1.0863 | 0.3793 | 1.1834 | 1.6060 | 3.7560 | 2.1550 |
| y_8 | 0.6538 | 0.7750 | 1.1494 | 1.4202 | 0.3845 | 1.4452 | 1.6805 | 2.1550 | 3.0331 |
$d$
| variable | offset |
|---|---|
| y_0 | 0.4399 |
| y_1 | 0.8723 |
| y_2 | 0.2250 |
| y_3 | 0.0971 |
| y_4 | 0.6572 |
| y_5 | 0.7544 |
| y_6 | 0.5670 |
| y_7 | 0.7409 |
| y_8 | 0.7357 |
$B$
| state | c_0 | c_1 | c_2 | c_3 | c_4 | c_5 | c_6 | c_7 | c_8 | c_9 | c_10 | c_11 | c_12 | c_13 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| x_0 | 0.0135 | 0.9418 | 0.6751 | 0.3042 | 0.0136 | 0.7803 | 0.2302 | 0.5920 | 0.7610 | 0.8504 | 0.2033 | 0.5990 | 0.8954 | 0.0604 |
| x_1 | 0.2530 | 0.1406 | 0.4280 | 0.1165 | 0.5817 | 0.2896 | 0.4509 | 0.2735 | 0.8425 | 0.5123 | 0.4176 | 0.5773 | 0.3222 | 0.5276 |
| x_2 | 0.4523 | 0.6324 | 0.4716 | 0.0785 | 0.9462 | 0.5346 | 0.9771 | 0.4970 | 0.5893 | 0.5292 | 0.6864 | 0.5196 | 0.8370 | 0.3849 |
| x_3 | 0.4922 | 0.3260 | 0.1990 | 0.6217 | 0.7287 | 0.4589 | 0.8500 | 0.1375 | 0.9457 | 0.8397 | 0.5203 | 0.8416 | 0.1371 | 0.5186 |
| x_4 | 0.4377 | 0.2392 | 0.4949 | 0.4146 | 0.3028 | 0.6810 | 0.1177 | 0.1563 | 0.2588 | 0.8996 | 0.9248 | 0.5575 | 0.2553 | 0.0631 |
| x_5 | 0.5413 | 0.5853 | 0.4166 | 0.9482 | 0.0665 | 0.4683 | 0.0348 | 0.6635 | 0.0501 | 0.1221 | 0.1268 | 0.7322 | 0.3311 | 0.0151 |
| x_6 | 0.1452 | 0.5820 | 0.8673 | 0.3090 | 0.8065 | 0.7325 | 0.1682 | 0.5885 | 0.1180 | 0.4120 | 0.2043 | 0.8200 | 0.5015 | 0.3238 |
| x_7 | 0.3842 | 0.4433 | 0.3052 | 0.4963 | 0.4459 | 0.9266 | 0.6286 | 0.8807 | 0.3252 | 0.0861 | 0.7891 | 0.1666 | 0.1766 | 0.0318 |
| x_8 | 0.8600 | 0.8088 | 0.8600 | 0.5418 | 0.4772 | 0.9634 | 0.3191 | 0.1484 | 0.6377 | 0.0586 | 0.5372 | 0.8380 | 0.8808 | 0.2243 |
| x_9 | 0.4277 | 0.0173 | 0.9436 | 0.3526 | 0.1852 | 0.2433 | 0.8409 | 0.7467 | 0.4969 | 0.2585 | 0.3466 | 0.4240 | 0.1253 | 0.2661 |
| x_10 | 0.9022 | 0.0314 | 0.0804 | 0.7244 | 0.3651 | 0.0938 | 0.8409 | 0.0069 | 0.3613 | 0.6663 | 0.1531 | 0.9582 | 0.1326 | 0.9434 |
| x_11 | 0.8686 | 0.9671 | 0.1879 | 0.7194 | 0.3153 | 0.5075 | 0.6469 | 0.0551 | 0.2449 | 0.5830 | 0.3328 | 0.4071 | 0.2686 | 0.4456 |
| x_12 | 0.5746 | 0.1570 | 0.5606 | 0.7224 | 0.6012 | 0.4299 | 0.0548 | 0.3849 | 0.0750 | 0.4321 | 0.9120 | 0.4023 | 0.5149 | 0.5738 |
| x_13 | 0.1813 | 0.1437 | 0.8099 | 0.2174 | 0.2784 | 0.7365 | 0.5066 | 0.1417 | 0.6935 | 0.0812 | 0.0792 | 0.1286 | 0.6698 | 0.1731 |
| x_14 | 0.3023 | 0.8685 | 0.0737 | 0.2969 | 0.0566 | 0.7863 | 0.9368 | 0.2227 | 0.0272 | 0.9288 | 0.2405 | 0.8415 | 0.4647 | 0.5220 |
| x_15 | 0.2359 | 0.5393 | 0.3662 | 0.9737 | 0.1073 | 0.0926 | 0.9738 | 0.8049 | 0.2272 | 0.4266 | 0.4965 | 0.2811 | 0.5143 | 0.1134 |
| x_16 | 0.8076 | 0.4430 | 0.9223 | 0.0757 | 0.7333 | 0.1208 | 0.4115 | 0.5446 | 0.8064 | 0.5765 | 0.2153 | 0.4235 | 0.2613 | 0.2662 |
| x_17 | 0.4906 | 0.6666 | 0.1782 | 0.4631 | 0.4471 | 0.4886 | 0.6511 | 0.1357 | 0.9547 | 0.8251 | 0.5739 | 0.0537 | 0.9671 | 0.1413 |
$m_0$
| state | mean |
|---|---|
| x_0 | 0.4748 |
| x_1 | 0.0525 |
| x_2 | 0.8524 |
| x_3 | 0.5821 |
| x_4 | 0.7281 |
| x_5 | 0.9879 |
| x_6 | 0.6011 |
| x_7 | 0.4692 |
| x_8 | 0.9031 |
| x_9 | 0.9123 |
| x_10 | 0.6185 |
| x_11 | 0.8070 |
| x_12 | 0.5830 |
| x_13 | 0.5986 |
| x_14 | 0.5898 |
| x_15 | 0.8722 |
| x_16 | 0.7868 |
| x_17 | 0.8305 |
$P_0$
| state | x_0 | x_1 | x_2 | x_3 | x_4 | x_5 | x_6 | x_7 | x_8 | x_9 | x_10 | x_11 | x_12 | x_13 | x_14 | x_15 | x_16 | x_17 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| x_0 | 0.7075 | 0.4575 | 0.1025 | 0.7552 | 0.4678 | 0.4102 | 0.6403 | 0.2264 | 0.0279 | 0.6776 | 0.8178 | 0.4180 | 0.7019 | 0.4730 | 0.7924 | 0.5532 | 0.6499 | 0.0084 |
| x_1 | 0.4575 | 1.3349 | 0.8975 | 1.0693 | 0.9878 | 0.5854 | 0.6734 | 0.6813 | 0.5162 | 1.2483 | 1.1523 | 0.6240 | 0.8392 | 0.6469 | 1.1093 | 0.4286 | 0.6817 | 0.5511 |
| x_2 | 0.1025 | 0.8975 | 1.7778 | 1.4223 | 1.0713 | 0.5301 | 0.4902 | 0.9565 | 0.7575 | 1.1096 | 0.6516 | 0.6540 | 1.3722 | 0.9620 | 0.6709 | 1.0418 | 0.7424 | 0.5799 |
| x_3 | 0.7552 | 1.0693 | 1.4223 | 2.9851 | 1.4541 | 1.0645 | 1.6541 | 0.9605 | 1.5344 | 2.4493 | 1.9649 | 1.5684 | 1.7126 | 1.3214 | 1.6471 | 1.7876 | 1.5110 | 1.4273 |
| x_4 | 0.4678 | 0.9878 | 1.0713 | 1.4541 | 1.6513 | 1.2777 | 1.3102 | 0.7676 | 1.0484 | 1.4509 | 1.6367 | 1.4200 | 1.9104 | 0.8253 | 1.1014 | 0.9313 | 1.6568 | 1.2782 |
| x_5 | 0.4102 | 0.5854 | 0.5301 | 1.0645 | 1.2777 | 2.0910 | 2.0046 | 1.1160 | 1.0262 | 1.7268 | 2.2526 | 2.0053 | 2.1983 | 1.2724 | 1.6174 | 1.5978 | 2.3012 | 1.8972 |
| x_6 | 0.6403 | 0.6734 | 0.4902 | 1.6541 | 1.3102 | 2.0046 | 3.2961 | 1.3658 | 2.1582 | 2.6066 | 2.7869 | 2.4754 | 2.3180 | 1.8782 | 1.9604 | 2.2224 | 3.0209 | 2.9851 |
| x_7 | 0.2264 | 0.6813 | 0.9565 | 0.9605 | 0.7676 | 1.1160 | 1.3658 | 2.2370 | 1.1397 | 2.0217 | 1.7425 | 1.4582 | 1.7065 | 1.2660 | 1.8500 | 1.3857 | 2.3256 | 1.4360 |
| x_8 | 0.0279 | 0.5162 | 0.7575 | 1.5344 | 1.0484 | 1.0262 | 2.1582 | 1.1397 | 3.1976 | 2.4686 | 2.6294 | 2.0770 | 2.1077 | 1.4005 | 1.6343 | 2.0925 | 2.2970 | 2.9207 |
| x_9 | 0.6776 | 1.2483 | 1.1096 | 2.4493 | 1.4509 | 1.7268 | 2.6066 | 2.0217 | 2.4686 | 3.8815 | 3.3607 | 2.4745 | 3.0405 | 2.4247 | 2.8433 | 3.0134 | 3.4315 | 3.3814 |
| x_10 | 0.8178 | 1.1523 | 0.6516 | 1.9649 | 1.6367 | 2.2526 | 2.7869 | 1.7425 | 2.6294 | 3.3607 | 4.7920 | 3.0654 | 3.5766 | 2.7664 | 3.4915 | 3.4792 | 4.0401 | 3.2582 |
| x_11 | 0.4180 | 0.6240 | 0.6540 | 1.5684 | 1.4200 | 2.0053 | 2.4754 | 1.4582 | 2.0770 | 2.4745 | 3.0654 | 3.1407 | 3.3492 | 2.5045 | 2.9033 | 2.5216 | 3.3064 | 2.7858 |
| x_12 | 0.7019 | 0.8392 | 1.3722 | 1.7126 | 1.9104 | 2.1983 | 2.3180 | 1.7065 | 2.1077 | 3.0405 | 3.5766 | 3.3492 | 6.7801 | 4.6939 | 3.9369 | 4.9600 | 4.8216 | 3.7575 |
| x_13 | 0.4730 | 0.6469 | 0.9620 | 1.3214 | 0.8253 | 1.2724 | 1.8782 | 1.2660 | 1.4005 | 2.4247 | 2.7664 | 2.5045 | 4.6939 | 5.2116 | 3.8491 | 4.6837 | 3.8005 | 3.5731 |
| x_14 | 0.7924 | 1.1093 | 0.6709 | 1.6471 | 1.1014 | 1.6174 | 1.9604 | 1.8500 | 1.6343 | 2.8433 | 3.4915 | 2.9033 | 3.9369 | 3.8491 | 5.2003 | 4.0112 | 3.8169 | 3.0359 |
| x_15 | 0.5532 | 0.4286 | 1.0418 | 1.7876 | 0.9313 | 1.5978 | 2.2224 | 1.3857 | 2.0925 | 3.0134 | 3.4792 | 2.5216 | 4.9600 | 4.6837 | 4.0112 | 6.7694 | 5.0850 | 4.5844 |
| x_16 | 0.6499 | 0.6817 | 0.7424 | 1.5110 | 1.6568 | 2.3012 | 3.0209 | 2.3256 | 2.2970 | 3.4315 | 4.0401 | 3.3064 | 4.8216 | 3.8005 | 3.8169 | 5.0850 | 6.7135 | 5.3989 |
| x_17 | 0.0084 | 0.5511 | 0.5799 | 1.4273 | 1.2782 | 1.8972 | 2.9851 | 1.4360 | 2.9207 | 3.3814 | 3.2582 | 2.7858 | 3.7575 | 3.5731 | 3.0359 | 4.5844 | 5.3989 | 7.9999 |
Fine tuning
Fine tune Variable | gap only for one variable | gap len 6-336
fine tune the model to only one variable
from fastcore.basics import *
from IPython.display import HTML

# Per-variable fine-tuning schedule: variable name -> list of training stages,
# each a dict with 'lr' (learning rate) and 'n' (number of epochs), consumed by
# fine_tune() via train_or_load.
var_learning = {
    'TA': [{'lr': 1e-3, 'n': 3}],
    # BUG FIX: 'SW_IN' was listed twice with identical values; a duplicate dict
    # key silently overwrites the first entry, so the duplicate is removed.
    'SW_IN': [{'lr': 1e-3, 'n': 4}],
    'LW_IN': [{'lr': 1e-3, 'n': 3}],
    'VPD': [{'lr': 1e-3, 'n': 3}],
    'WS': [{'lr': 1e-3, 'n': 3}],
    'PA': [{'lr': 1e-3, 'n': 3}],
    # 'P': [{'lr': 1e-3, 'n': 3}], missing on purpose
    'SWC': [{'lr': 1e-3, 'n': 5}, {'lr': 1e-5, 'n': 1}],
    'TS': [{'lr': 1e-3, 'n': 5}],
}
warn(f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this")
def fine_tune(var_learning, learn):
    """Fine-tune a fresh copy of ``learn.model`` separately for each variable.

    `var_learning` maps variable name -> list of training stages, each a dict
    with 'lr' (learning rate) and 'n' (epochs).  A per-variable dataloader with
    only that variable gapped (gap len 6-336) is built for each.
    Returns ``(models, dls, learners, items)`` dicts keyed by variable name.
    """
    spec_models = {}
    spec_dls = {}
    spec_learn = {}
    spec_items = {}
    for var in tqdm(var_learning.keys()):
        display(HTML(f"<h4> {var} | Gap len 6-336 finetune</h4>"))
        spec_dls[var] = imp_dataloader(hai, hai_era, var_sel=var, block_len=100+336,
                                       gap_len=gen_gap_len(6, 336), bs=20, control_lags=[1],
                                       n_rep=3, shifts=gen_shifts(50)).cpu()
        if show_metrics:
            display(HTML("Metrics generic model"))
            display(metric_valid(learn, dls=spec_dls[var].valid))
        for i, train in enumerate(var_learning[var]):
            # BUG FIX: the original `lr, n = train` unpacked the dict's *keys*
            # (the strings 'lr' and 'n'), not its values; read the values, as
            # fine_tune2 below does.
            lr, n = train['lr'], train['n']
            display(HTML(f"train {i}"))
            spec_models[var] = learn.model.copy()
            # train_or_load's 3rd positional arg is the epoch count and the 4th
            # the learning rate (see its docstring), so pass (n, lr) — matching
            # fine_tune2 and every direct call in this notebook.
            spec_learn[var], spec_items[var] = train_or_load(
                spec_models[var], spec_dls[var], n, lr,
                base / f"{var}_specialized_gap_6-336_v1_{i}")
            plt.show()
    return spec_models, spec_dls, spec_learn, spec_items

spec_models, spec_dls, spec_learn, spec_items = fine_tune(var_learning, learn_A1v)
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | -73.308072 | -59.729135 | 0.155103 | 23:04 |
| 1 | 1 | -87.049737 | -73.262853 | 0.139768 | 20:37 |
| 2 | 2 | -92.947376 | -82.557740 | 0.131374 | 21:16 |

SW_IN | Gap len 6-336 finetune
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | 53.445060 | 43.464660 | 0.286705 | 23:25 |
| 1 | 1 | 49.458273 | 42.814378 | 0.285042 | 23:47 |
| 2 | 2 | 48.186476 | 43.087234 | 0.283170 | 22:43 |

LW_IN | Gap len 6-336 finetune
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | 104.665918 | 106.041969 | 0.414556 | 22:34 |
| 1 | 1 | 101.284686 | 107.526856 | 0.419593 | 23:28 |
| 2 | 2 | 99.767878 | 108.885340 | 0.420613 | 23:43 |

VPD | Gap len 6-336 finetune
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | 52.165032 | 50.852617 | 0.295505 | 23:31 |
| 1 | 1 | 45.131744 | 36.635116 | 0.272213 | 22:41 |
| 2 | 2 | 41.990330 | 32.914802 | 0.264974 | 23:06 |

WS | Gap len 6-336 finetune
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | 175.891940 | 246.176982 | 0.761467 | 22:58 |
| 1 | 1 | 165.224262 | 235.823659 | 0.719208 | 23:03 |
| 2 | 2 | 159.586716 | 246.720599 | 0.727202 | 21:00 |

PA | Gap len 6-336 finetune
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | -99.014311 | -97.738537 | 0.127062 | 21:53 |
| 1 | 1 | -123.031773 | -104.182228 | 0.120468 | 22:45 |
| 2 | 2 | -130.901483 | -133.160406 | 0.104076 | 24:11 |

SWC | Gap len 6-336 finetune
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | 172.303569 | 200.127139 | 0.599423 | 23:57 |
| 1 | 1 | 132.754943 | 76.512459 | 0.305065 | 22:01 |
| 2 | 2 | 93.005439 | 57.391193 | 0.270830 | 20:58 |

| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | 172.303569 | 200.127139 | 0.599423 | 23:57 |
| 1 | 1 | 132.754943 | 76.512459 | 0.305065 | 22:01 |
| 2 | 2 | 93.005439 | 57.391193 | 0.270830 | 20:58 |

TS | Gap len 6-336 finetune
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | 68.246805 | 54.157492 | 0.268725 | 23:05 |
| 1 | 1 | 57.107431 | 39.662085 | 0.247396 | 21:34 |
| 2 | 2 | 57.394308 | 53.605346 | 0.260903 | 20:52 |

# Second round of fine-tuning for the variables that were still improving.
var_learning2 = {
    'TA': [{'lr': 1e-3, 'n': 3}],
    'VPD': [{'lr': 1e-3, 'n': 2}],
    'PA': [{'lr': 1e-3, 'n': 2}],
    'SWC': [{'lr': 1e-3, 'n': 3}, {'lr': 1e-5, 'n': 1}],
    'TS': [{'lr': 1e-3, 'n': 2}],
}

def fine_tune2(var_learning, spec_dls, spec_learn, spec_items):
    """Continue fine-tuning existing per-variable learners on their existing dataloaders.

    Each stage dict carries 'lr', 'n' (epochs), and optionally 'v' — the
    checkpoint version tag used in the saved file name (default 2).
    ``spec_learn`` is shallow-copied so the caller's dict is not mutated.
    Returns the (unchanged) dls and items dicts plus the updated learners.
    """
    spec_learn = spec_learn.copy()
    for var in tqdm(var_learning.keys()):
        display(HTML(f"<h4> {var} | Gap len 6-336 finetune 2 </h4>"))
        for i, train in enumerate(var_learning[var]):
            lr, n = train['lr'], train['n']
            v = train.get('v', 2)
            display(HTML(f"train {i}"))
            # 3rd arg = epochs, 4th = learning rate (train_or_load convention).
            spec_learn[var], _ = train_or_load(spec_learn[var].model, spec_dls[var], n, lr,
                                               path=base / f"{var}_specialized_gap_6-336_v{v}_{i}")
            plt.show()
    return spec_dls, spec_learn, spec_items

spec_dls2, spec_learn2, spec_items2 = fine_tune2(var_learning2, spec_dls, spec_learn, spec_items)
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | -97.718002 | -90.171141 | 0.125761 | 21:44 |
| 1 | 1 | -100.493585 | -88.338616 | 0.126856 | 20:23 |
| 2 | 2 | -103.769092 | -73.432097 | 0.135440 | 20:41 |

SW_IN | Gap len 6-336 finetune 2
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | 45.373127 | 42.291928 | 0.282366 | 20:34 |
| 1 | 1 | 46.487908 | 41.780965 | 0.283791 | 20:26 |
| 2 | 2 | 46.903155 | 45.729009 | 0.290985 | 20:43 |
| 3 | 3 | 46.917554 | 39.832071 | 0.278899 | 20:29 |

LW_IN | Gap len 6-336 finetune 2
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | 98.188426 | 107.646729 | 0.414824 | 20:22 |
| 1 | 1 | 95.761911 | 108.694974 | 0.416340 | 20:30 |
| 2 | 2 | 97.285924 | 106.717304 | 0.415862 | 20:30 |

VPD | Gap len 6-336 finetune 2
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | 42.370522 | 31.185345 | 0.254092 | 20:27 |
| 1 | 1 | 37.933048 | 29.823322 | 0.255506 | 20:25 |
| 2 | 2 | 35.904875 | 25.101849 | 0.245977 | 20:28 |

WS | Gap len 6-336 finetune 2
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | 158.971839 | 247.131486 | 0.730278 | 20:37 |
| 1 | 1 | 158.584698 | 253.701959 | 0.736434 | 20:34 |
| 2 | 2 | 156.970299 | 263.529629 | 0.749744 | 20:22 |

PA | Gap len 6-336 finetune 2
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | -131.680055 | -112.508634 | 0.115626 | 23:56 |
| 1 | 1 | -145.849450 | -95.348580 | 0.114896 | 24:38 |
| 2 | 2 | -145.281960 | -107.879192 | 0.116886 | 24:47 |

SWC | Gap len 6-336 finetune 2
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | 35.588624 | -20.680066 | 0.174256 | 24:56 |
| 1 | 1 | 10.683595 | -23.630941 | 0.167582 | 24:32 |
| 2 | 2 | -34.885516 | -76.834793 | 0.124991 | 25:11 |
| 3 | 3 | -47.704883 | 0.936394 | 0.166095 | 24:47 |
| 4 | 4 | -61.871826 | -75.499625 | 0.120001 | 24:45 |

| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | -82.762576 | -101.858651 | 0.103941 | 24:41 |

TS | Gap len 6-336 finetune 2
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | 40.478151 | 29.605951 | 0.230160 | 25:47 |
| 1 | 1 | 36.143091 | 21.822619 | 0.211396 | 28:18 |
| 2 | 2 | 30.493916 | -0.701642 | 0.181059 | 24:59 |
| 3 | 3 | 26.646244 | 17.138843 | 0.201734 | 27:48 |
| 4 | 4 | 19.231903 | 10.231487 | 0.202251 | 29:09 |

# Final low-learning-rate polish for TA and PA (checkpoint version 3).
var_learning3 = {
    'TA': [{'lr': 1e-5, 'n': 1, 'v': 3}],
    'PA': [{'lr': 1e-5, 'n': 1, 'v': 3}],
}
spec_dls3, spec_learn3, spec_items3 = fine_tune2(var_learning3, spec_dls, spec_learn, spec_items)
| epoch | train_loss | valid_loss | rmse_gap | time |
|---|---|---|---|---|
| 0 | -98.172894 | -86.769272 | 0.131181 | 28:21 |

PA | Gap len 6-336 finetune 2
| epoch | train_loss | valid_loss | rmse_gap | time |
|---|---|---|---|---|
| 0 | -155.452174 | -135.116148 | 0.105047 | 27:51 |

OLD (superseded — kept for reference)
# Train one specialized model per variable (gaps of length 12-336 in that
# variable only), starting each from a copy of the generic model learn_A1v.
spec_models = {}
spec_dls = {}
spec_learn = {}
spec_items = {}
for var in tqdm(list(hai.columns)):
display(HTML(f"<h4> {var} | Gap len 12-336 </h4>"))
# One dataloader per variable: gap only in `var`, random gap length 12-336.
spec_dls[var] = imp_dataloader(hai, hai_era, var_sel = var, block_len=100+336, gap_len=gen_gap_len(12, 336), bs=20, control_lags=[1], n_rep=3, shifts=gen_shifts(50)).cpu()
# Start from a copy so the generic model's weights are not modified in place.
spec_models[var] = learn_A1v.model.copy()
if show_metrics:
display(HTML("Metrics generic model"))
display(metric_valid(learn_A1v, dls=spec_dls[var].valid))
# NOTE(review): train_or_load forwards (lr, n) to learn.fit(lr, n); fastai's
# fit signature is (n_epoch, lr), so here 3 is the epoch count and 1e-3 the
# learning rate despite the parameter names — confirm against train_or_load.
spec_learn[var], spec_items[var] = train_or_load(spec_models[var], spec_dls[var], 3, 1e-3, base / f"{var}_specialized_gap_12-336_v1")
plt.show()TA | Gap len 12-336
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | -31.277383 | -59.727349 | 0.155292 | 23:16 |
| 1 | 1 | -57.410593 | -83.423924 | 0.135217 | 24:38 |
| 2 | 2 | -69.265146 | -81.197435 | 0.137000 | 21:50 |

SW_IN | Gap len 12-336
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | 72.838760 | 63.622077 | 0.317182 | 22:33 |
| 1 | 1 | 60.514370 | 53.416147 | 0.296670 | 22:32 |
| 2 | 2 | 55.572469 | 48.728888 | 0.287208 | 25:43 |

LW_IN | Gap len 12-336
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | 110.011858 | 112.541507 | 0.435194 | 24:28 |
| 1 | 1 | 105.027452 | 106.923631 | 0.417517 | 24:23 |
| 2 | 2 | 102.755247 | 106.489712 | 0.415183 | 24:53 |

VPD | Gap len 12-336
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | 67.771759 | 71.121287 | 0.306581 | 25:12 |
| 1 | 1 | 58.432390 | 49.351550 | 0.275882 | 23:40 |
| 2 | 2 | 49.884968 | 41.432351 | 0.264706 | 25:51 |

WS | Gap len 12-336
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | 169.196013 | 241.890813 | 0.734134 | 24:18 |
| 1 | 1 | 165.496950 | 235.745148 | 0.711773 | 23:20 |
| 2 | 2 | 161.570979 | 223.982837 | 0.696192 | 22:10 |

PA | Gap len 12-336
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | -73.724512 | -103.872036 | 0.124549 | 27:03 |
| 1 | 1 | -111.420458 | -130.842109 | 0.104550 | 30:01 |
| 2 | 2 | -133.129972 | -140.989809 | 0.100966 | 24:07 |

P | Gap len 12-336
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | 259.814830 | 216.069186 | 0.602446 | 24:34 |
| 1 | 1 | 248.654373 | 209.606269 | 0.607980 | 23:06 |
| 2 | 2 | 242.001831 | 206.787110 | 0.605378 | 25:11 |

SWC | Gap len 12-336
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | 207.952983 | 194.924715 | 0.646890 | 24:35 |
| 1 | 1 | 192.451676 | 184.576861 | 0.595377 | 24:16 |
| 2 | 2 | 169.086573 | 117.395571 | 0.388920 | 22:44 |

TS | Gap len 12-336
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | 63.610211 | 40.488381 | 0.255747 | 21:39 |
| 1 | 1 | 57.377678 | 39.280813 | 0.247606 | 21:41 |
| 2 | 2 | 50.239735 | 23.977059 | 0.231418 | 21:34 |

Additional training
# Second round of training: continue from a copy of each round-1 specialized
# model (subset of variables only), saving under the v2 artifact name.
spec_learn2 = {}
for var in tqdm(['TA', 'SW_IN', 'WS', 'PA', 'VPD', 'TS', 'SWC']):
display(HTML(f"<h4> {var} | Gap len 12-336 </h4> | Training 2"))
# NOTE(review): the path already ends in ".pickle" and train_or_load's
# add_end(".pickle") appends another, giving "...v2.pickle.pickle" on disk —
# inconsistent with the v1 call above; confirm which artifact name is real.
spec_learn2[var], _ = train_or_load(spec_learn[var].model.copy(), spec_dls[var], 2, 1e-3, base / f"{var}_specialized_gap_12-336_v2.pickle")
plt.show()
TA | Gap len 12-336
| Training 2| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | -82.582890 | -78.921598 | 0.137891 | 21:39 |
| 1 | 1 | -86.119136 | -89.580120 | 0.128288 | 21:18 |

SW_IN | Gap len 12-336
| Training 2| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | 52.668345 | 46.674095 | 0.288313 | 21:53 |
| 1 | 1 | 51.089022 | 44.378815 | 0.283374 | 22:28 |

WS | Gap len 12-336
| Training 2| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | 164.019476 | 232.919238 | 0.710805 | 22:37 |
| 1 | 1 | 164.185435 | 225.276840 | 0.707446 | 22:50 |

PA | Gap len 12-336
| Training 2| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | -147.345027 | -136.658365 | 0.103830 | 22:42 |
| 1 | 1 | -157.427610 | -119.642745 | 0.108054 | 22:36 |

VPD | Gap len 12-336
| Training 2| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | 46.489590 | 30.196986 | 0.249948 | 22:27 |
| 1 | 1 | 45.226931 | 46.153450 | 0.273411 | 21:01 |

TS | Gap len 12-336
| Training 2| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | 39.863970 | 17.692497 | 0.221164 | 20:40 |
| 1 | 1 | 31.964869 | 3.826767 | 0.209781 | 20:34 |

SWC | Gap len 12-336
| Training 2| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | 114.122129 | 96.942318 | 0.352660 | 20:44 |
| 1 | 1 | 67.486398 | 34.946166 | 0.239454 | 20:39 |

PA and VPD are overfitting so repeat training 2 with only one batch
spec_learn3 = {}
for var in tqdm(['PA', 'VPD']):
display(HTML(f"<h4> {var} | Gap len 12-336 | Training 3 </h4>"))
spec_learn3[var], _ = train_or_load(spec_learn[var].model.copy(), spec_dls[var], 1, 1e-3, base / f"{var}_specialized_gap_12-336_v3.pickle")
plt.show()
PA | Gap len 12-336 | Training 3
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | -135.310833 | -106.027287 | 0.120461 | 24:33 |

VPD | Gap len 12-336 | Training 3
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | 41.435748 | 32.315761 | 0.254877 | 24:13 |

spec_learn3 = {}
for var in tqdm(['TS']):
display(HTML(f"<h4> {var} | Gap len 12-336 | Training 1+2+3 </h4>"))
spec_learn3[var], _ = train_or_load(spec_learn2[var].model.copy(), spec_dls[var], 1, 1e-5, base / f"{var}_specialized_gap_12-336_v3")
plt.show()
TS | Gap len 12-336 | Training 1+2+3
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | 16.858002 | -7.042792 | 0.189336 | 23:31 |

for var in tqdm(['SWC']):
display(HTML(f"<h4> {var} | Gap len 12-336 | Training 1+2+3 </h4>"))
spec_learn3[var], _ = train_or_load(spec_learn2[var].model.copy(), spec_dls[var], 1, 1e-5, base / f"{var}_specialized_gap_12-336_v3")
plt.show()
SWC | Gap len 12-336 | Training 1+2+3
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | 3.990423 | -0.629562 | 0.185285 | 24:57 |

for var in tqdm(['SWC']):
display(HTML(f"<h4> {var} | Gap len 12-336 | Training 1+2+3 </h4>"))
spec_learn3[var], _ = train_or_load(spec_learn2[var].model.copy(), spec_dls[var], 2, 1e-3, base / f"{var}_specialized_gap_12-336_v3_2")
plt.show()
SWC | Gap len 12-336 | Training 1+2+3
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | -1.621724 | -38.955631 | 0.146205 | 23:45 |
| 1 | 1 | -35.639923 | -57.006918 | 0.130909 | 22:09 |

the training loss is getting worse … so trying with a smaller learning rate
for var in tqdm(['SWC']):
display(HTML(f"<h4> {var} | Gap len 12-336 | Training 1+2+3 </h4>"))
spec_learn3[var], _ = train_or_load(spec_learn3[var].model.copy(), spec_dls[var], 1, 1e-5, base / f"{var}_specialized_gap_12-336_v3_3")
plt.show()
SWC | Gap len 12-336 | Training 1+2+3
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | -75.585425 | -79.693433 | 0.113264 | 25:09 |

for var in tqdm(['SWC']):
display(HTML(f"<h4> {var} | Gap len 12-336 | Training 1+2+3+4 </h4>"))
spec_learn3[var], _ = train_or_load(spec_learn3[var].model.copy(), spec_dls[var], 1, 1e-5, base / f"{var}_specialized_gap_12-336_v3_4")
plt.show()
SWC | Gap len 12-336 | Training 1+2+3+4
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | -79.65696 | -81.566511 | 0.113204 | 25:25 |

spec_learn4 = {}
for var in tqdm(['SWC']):
display(HTML(f"<h4> {var} | Gap len 12-336 | Training 1+2+3+4 </h4>"))
spec_learn3[var], _ = train_or_load(spec_learn3[var].model.copy(), spec_dls[var], 1, 1e-5, base / f"{var}_specialized_gap_12-336_v4")
plt.show()
SWC | Gap len 12-336 | Training 1+2+3+4
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | -0.957575 | -7.084297 | 0.173571 | 24:35 |

for var in tqdm(['WS']):
display(HTML(f"<h4> {var} | Gap len 12-336 | Training 1+2+3 </h4>"))
spec_learn3[var], _ = train_or_load(spec_learn2[var].model.copy(), spec_dls[var], 1, 1e-5, base / f"{var}_specialized_gap_12-336_v3")
plt.show()
WS | Gap len 12-336 | Training 1+2+3
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | 158.930217 | 224.42966 | 0.701698 | 24:04 |

this is overfitting PA
for var in tqdm(['PA']):
display(HTML(f"<h4> {var} | Gap len 12-336 | Training 1++3 </h4>"))
spec_learn3[var], _ = train_or_load(spec_learn[var].model.copy(), spec_dls[var], 1, 1e-4, base / f"{var}_specialized_gap_12-336_v3")
plt.show()
PA | Gap len 12-336 | Training 1++3
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | -168.729458 | -122.183789 | 0.10763 | 23:40 |

spec_learn4 = {}
for var in tqdm(['PA']):
display(HTML(f"<h4> {var} | Gap len 12-336 | Training 1+3+4 </h4>"))
spec_learn3[var], _ = train_or_load(spec_learn3[var].model.copy(), spec_dls[var], 1, 1e-4, base / f"{var}_specialized_gap_12-336_v4")
plt.show()
PA | Gap len 12-336 | Training 1+3+4
| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | -169.911168 | -117.59576 | 0.109397 | 21:33 |

Result visualization
the goal of this section is to figure out why the generic model for all gaps works better than the generic model with a gap in only one variable
show_results(learn_A1v, items = spec_items['TA'], hide_no_gap=True)NameError: name 'spec_items' is not defined
show_results(spec_learn2['TA'], items = spec_items['TA'], hide_no_gap=True)NameError: name 'spec_learn2' is not defined
show_results(learn_Aa, items = spec_items['TA'], hide_no_gap=True, dls=learn_A1v.dls)[MeteoImpItem(i=493, shift=-3, var_sel=['TA'], gap_len=263), MeteoImpItem(i=485, shift=-30, var_sel=['TA'], gap_len=318), MeteoImpItem(i=520, shift=29, var_sel=['TA'], gap_len=28), MeteoImpItem(i=418, shift=-10, var_sel=['TA'], gap_len=234)]
no control to correlation is bad
with with_settings(learn_Aa.model, use_control=False):
display(show_results(learn_Aa, items = spec_items['TA'], hide_no_gap=True, dls=learn_A1v.dls))[MeteoImpItem(i=493, shift=-3, var_sel=['TA'], gap_len=263), MeteoImpItem(i=485, shift=-30, var_sel=['TA'], gap_len=318), MeteoImpItem(i=520, shift=29, var_sel=['TA'], gap_len=28), MeteoImpItem(i=418, shift=-10, var_sel=['TA'], gap_len=234)]
there is correlation here but no control,
the error is higher and the uncertainty is higher
display(show_results(learn_A1v_nc, items = spec_items['TA'], hide_no_gap=True, dls=learn_A1v.dls))[MeteoImpItem(i=493, shift=-3, var_sel=['TA'], gap_len=263), MeteoImpItem(i=485, shift=-30, var_sel=['TA'], gap_len=318), MeteoImpItem(i=520, shift=29, var_sel=['TA'], gap_len=28), MeteoImpItem(i=418, shift=-10, var_sel=['TA'], gap_len=234)]
display(show_results(learn_A1v_nc, items = items_Aa, hide_no_gap=True, dls=learn_Aa.dls))[MeteoImpItem(i=1591, shift=-24, var_sel=['TA', 'SW_IN', 'LW_IN', 'VPD', 'WS', 'PA', 'P', 'SWC', 'TS'], gap_len=30), MeteoImpItem(i=1836, shift=-12, var_sel=['TA', 'SW_IN', 'LW_IN', 'VPD', 'WS', 'PA', 'P', 'SWC', 'TS'], gap_len=30), MeteoImpItem(i=1643, shift=0, var_sel=['TA', 'SW_IN', 'LW_IN', 'VPD', 'WS', 'PA', 'P', 'SWC', 'TS'], gap_len=30), MeteoImpItem(i=1664, shift=-48, var_sel=['TA', 'SW_IN', 'LW_IN', 'VPD', 'WS', 'PA', 'P', 'SWC', 'TS'], gap_len=30)]
display(show_results(learn_A1v, items = items_Aa, hide_no_gap=True, dls=learn_Aa.dls))[MeteoImpItem(i=1591, shift=-24, var_sel=['TA', 'SW_IN', 'LW_IN', 'VPD', 'WS', 'PA', 'P', 'SWC', 'TS'], gap_len=30), MeteoImpItem(i=1836, shift=-12, var_sel=['TA', 'SW_IN', 'LW_IN', 'VPD', 'WS', 'PA', 'P', 'SWC', 'TS'], gap_len=30), MeteoImpItem(i=1643, shift=0, var_sel=['TA', 'SW_IN', 'LW_IN', 'VPD', 'WS', 'PA', 'P', 'SWC', 'TS'], gap_len=30), MeteoImpItem(i=1664, shift=-48, var_sel=['TA', 'SW_IN', 'LW_IN', 'VPD', 'WS', 'PA', 'P', 'SWC', 'TS'], gap_len=30)]
this has the control
items_Aa_TA = [MeteoImpItem(i.i, i.shift, 'TA', i.gap_len) for i in items_Aa]display(show_results(learn_Aa, items = items_Aa, hide_no_gap=True, dls=learn_Aa.dls))[MeteoImpItem(i=1591, shift=-24, var_sel=['TA', 'SW_IN', 'LW_IN', 'VPD', 'WS', 'PA', 'P', 'SWC', 'TS'], gap_len=30), MeteoImpItem(i=1836, shift=-12, var_sel=['TA', 'SW_IN', 'LW_IN', 'VPD', 'WS', 'PA', 'P', 'SWC', 'TS'], gap_len=30), MeteoImpItem(i=1643, shift=0, var_sel=['TA', 'SW_IN', 'LW_IN', 'VPD', 'WS', 'PA', 'P', 'SWC', 'TS'], gap_len=30), MeteoImpItem(i=1664, shift=-48, var_sel=['TA', 'SW_IN', 'LW_IN', 'VPD', 'WS', 'PA', 'P', 'SWC', 'TS'], gap_len=30)]
display(show_results(learn_Aa, items = items_Aa_TA, hide_no_gap=True, dls=learn_Aa.dls))[MeteoImpItem(i=1591, shift=-24, var_sel=['TA'], gap_len=30), MeteoImpItem(i=1836, shift=-12, var_sel=['TA'], gap_len=30), MeteoImpItem(i=1643, shift=0, var_sel=['TA'], gap_len=30), MeteoImpItem(i=1664, shift=-48, var_sel=['TA'], gap_len=30)]
so the problem is that this is worse than the one above, even though it should not be the case
display(show_results(learn_A1v, items = items_Aa_TA, hide_no_gap=True, dls=learn_Aa.dls))[MeteoImpItem(i=1591, shift=-24, var_sel=['TA'], gap_len=30), MeteoImpItem(i=1836, shift=-12, var_sel=['TA'], gap_len=30), MeteoImpItem(i=1643, shift=0, var_sel=['TA'], gap_len=30), MeteoImpItem(i=1664, shift=-48, var_sel=['TA'], gap_len=30)]
display(show_results(learn_Aa, items = spec_items['TA'], hide_no_gap=True, dls=learn_A1v.dls))[MeteoImpItem(i=493, shift=-3, var_sel=['TA'], gap_len=263), MeteoImpItem(i=485, shift=-30, var_sel=['TA'], gap_len=318), MeteoImpItem(i=520, shift=29, var_sel=['TA'], gap_len=28), MeteoImpItem(i=418, shift=-10, var_sel=['TA'], gap_len=234)]
even for longer gaps it is the same issue, where the generic model trained with gaps in all variables is worse than the generic model with gaps in none
display(show_results(learn_A1v, items = spec_items['TA'], hide_no_gap=True, dls=learn_A1v.dls))[MeteoImpItem(i=493, shift=-3, var_sel=['TA'], gap_len=263), MeteoImpItem(i=485, shift=-30, var_sel=['TA'], gap_len=318), MeteoImpItem(i=520, shift=29, var_sel=['TA'], gap_len=28), MeteoImpItem(i=418, shift=-10, var_sel=['TA'], gap_len=234)]
this is just very bad!
display(show_results(learn_A1v, items = items_Aa, hide_no_gap=True, dls=learn_Aa.dls))[MeteoImpItem(i=1591, shift=-24, var_sel=['TA', 'SW_IN', 'LW_IN', 'VPD', 'WS', 'PA', 'P', 'SWC', 'TS'], gap_len=30), MeteoImpItem(i=1836, shift=-12, var_sel=['TA', 'SW_IN', 'LW_IN', 'VPD', 'WS', 'PA', 'P', 'SWC', 'TS'], gap_len=30), MeteoImpItem(i=1643, shift=0, var_sel=['TA', 'SW_IN', 'LW_IN', 'VPD', 'WS', 'PA', 'P', 'SWC', 'TS'], gap_len=30), MeteoImpItem(i=1664, shift=-48, var_sel=['TA', 'SW_IN', 'LW_IN', 'VPD', 'WS', 'PA', 'P', 'SWC', 'TS'], gap_len=30)]
closer look at parameters
import seaborn as snsGeneric 1 gap
def plot_model_params(model):
    """Plot annotated heatmaps of the Kalman filter parameter matrices.

    Shows H, A, B, P0 and m0 of ``model`` one figure at a time, each with a
    diverging palette centered on 0. The process-noise matrix Q is currently
    skipped (see note at the end).

    Parameters
    ----------
    model : Kalman filter model with 1-batched parameter tensors
        (each accessed as ``model.X.squeeze(0)``).
    """
    # (title, matrix, figsize, vmax). figsize=None keeps the previously set
    # rc figure size, matching the original call sequence exactly: H was
    # plotted at (15, 10) with vmax=1.5; A and B at (15, 15) with vmax=1;
    # P0 and m0 inherited the (15, 15) setting.
    specs = [
        ("H", model.H, (15, 10), 1.5),
        ("A", model.A, (15, 15), 1),
        ("B", model.B, (15, 15), 1),
        ("P0", model.P0, None, 1),
        ("m0", model.m0, None, 1),
    ]
    for title, param, figsize, vmax in specs:
        if figsize is not None:
            sns.set(rc={"figure.figsize": figsize})
        sns.heatmap(array2df(param.squeeze(0)), annot=True, vmin=-1, vmax=vmax, center=0,
                    cmap=sns.diverging_palette(20, 220, n=200), square=True, cbar=False)
        plt.title(title)
        plt.show()
    # Q intentionally not plotted for now; add ("Q", model.Q, (15, 15), 1)
    # to `specs` above to inspect the process-noise covariance.
plot_model_params(learn_A1v.model)




plot_model_params(learn_Aa.model)




plot_model_params(learn_A1v30.model)




plot_model_params(spec_learn2['TA'].model)


plot_model_params(learn_Aa.model)


sns.heatmap(array2df((learn_Aa.model.H - learn_A1v30.model.H).squeeze(0)), annot=True, center=0,
cmap=sns.diverging_palette(20, 220, n=200), square=True, cbar=False)
plt.title("diff H")Text(0.5, 1.0, 'diff H')

sns.heatmap(array2df((learn_Aa.model.A - learn_A1v30.model.A).squeeze(0)), annot=True, center=0,
cmap=sns.diverging_palette(20, 220, n=200), square=True, cbar=False)
plt.title("diff A")Text(0.5, 1.0, 'diff A')

sns.heatmap(array2df((learn_Aa.model.B - learn_A1v30.model.B).squeeze(0)), annot=True, center=0,
cmap=sns.diverging_palette(20, 220, n=200), square=True, cbar=False)
plt.title("diff B")Text(0.5, 1.0, 'diff B')

sns.set(rc={"figure.figsize":(15, 10)})
sns.heatmap(array2df(learn_A1v.model.H.squeeze(0)), annot=True, vmin=-1, vmax=1, center=0,
cmap=sns.diverging_palette(20, 220, n=200), square=True)<AxesSubplot: >

sns.set(rc={"figure.figsize":(15, 15)})
sns.heatmap(array2df(learn_A1v.model.A.squeeze(0)), annot=True, vmin=-1, vmax=1, center=0,
cmap=sns.diverging_palette(20, 220, n=200), square=True)<AxesSubplot: >

this is okay but not much better
display(show_results(spec_learn2['TA'], items = items_Aa_TA, hide_no_gap=True, dls=learn_Aa.dls))[MeteoImpItem(i=1591, shift=-24, var_sel=['TA'], gap_len=30), MeteoImpItem(i=1836, shift=-12, var_sel=['TA'], gap_len=30), MeteoImpItem(i=1643, shift=0, var_sel=['TA'], gap_len=30), MeteoImpItem(i=1664, shift=-48, var_sel=['TA'], gap_len=30)]
Finetune gap length
# Fine-tune the generic single-variable-gap model for several fixed gap
# lengths (6, 24, 48 and one week = 336 half-hours), each starting from a
# fresh copy of learn_A1v's weights.
gap_models = {}
gap_dls = {}
gap_learn = {}
gap_items = {}
for gap_len in tqdm([6,24,48,7*48]):
    print(f"Gap len: {gap_len}")
    # Gap in exactly one (randomly selected) variable per item.
    gap_dls[gap_len] = imp_dataloader(hai, hai_era, var_sel = gen_var_sel(list(hai.columns), n_var=1), block_len=gap_len+100, gap_len=gap_len, bs=20, control_lags=[1], n_rep=1, shifts=gen_shifts(50)).cpu()
    gap_models[gap_len] = learn_A1v.model.copy()
    # BUG FIX: this originally read `display_metric`, which is undefined and
    # raised a NameError at runtime; the flag defined and used elsewhere in
    # this notebook is `show_metrics`.
    if show_metrics: display(metric_valid(learn_A1v, dls=gap_dls[gap_len].valid))
    # NOTE(review): train_or_load appends ".pickle" itself, so this artifact
    # is saved as "...v1.pickle.pickle" — confirm intended.
    gap_learn[gap_len], gap_items[gap_len] = train_or_load(gap_models[gap_len], gap_dls[gap_len], 3, 2e-5, base / f"gap_1_any_var_gap_len_{gap_len}_v1.pickle")
    plt.show()
Gap len: 6
NameError: name 'display_metric' is not defined
Fine tune TA for diff gap lens
as an experiment, TA is fine-tuned for a gap of 24
dls_TA24 = imp_dataloader(
df = hai,
control = hai_era,
var_sel = 'TA',
block_len=100+24,
gap_len=24,
bs=20,
control_lags=[1],
shifts=gen_shifts(50),
n_rep=1).cpu()model_TA24 = learn_A1v.model.copy()metric_valid(learn_A1v, dls= dls_TA24.valid)| loss | rmse | |
|---|---|---|
| mean | -2.338763 | 0.184055 |
| std | 2.551372 | 0.034550 |
learn_TA24, items_TA24 = train_or_load(model_TA24, dls_TA24, 1, 1e-4, base / "TA_gap_24_v1")| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | -2.51402 | -3.892453 | 0.174431 | 02:41 |

learn_TA24, items_TA24 = train_or_load(model_TA24, dls_TA24, 1, 1e-4, base / "TA_gap_24_v1_2")| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | -4.713286 | -5.854372 | 0.161322 | 02:37 |

learn_TA24, items_TA24 = train_or_load(model_TA24, dls_TA24, 1, 1e-4, base / "TA_gap_24_v1_3")| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | -6.744339 | -7.335044 | 0.152876 | 02:39 |

learn_TA24, items_TA24 = train_or_load(model_TA24, dls_TA24, 1, 5e-5, base / "TA_gap_24_v1_4")| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | -7.932901 | -8.451432 | 0.144813 | 02:42 |

learn_TA24, items_TA24 = train_or_load(model_TA24, dls_TA24, 3, 3e-5, base / "TA_gap_24_v1_4")| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | -7.932901 | -8.451432 | 0.144813 | 02:42 |

metric_valid(learn_TA24, dls= dls_TA24.valid)| loss | rmse | |
|---|---|---|
| mean | -9.049459 | 0.140859 |
| std | 2.104846 | 0.023429 |
dls_TA48 = imp_dataloader(
df = hai,
control = hai_era,
var_sel = 'TA',
block_len=100+48,
gap_len=48,
bs=20,
control_lags=[1],
shifts=gen_shifts(50),
n_rep=1).cpu()model_TA48 = learn_TA24.model.copy()metric_valid(learn_TA24, dls= dls_TA48.valid)| loss | rmse | |
|---|---|---|
| mean | -11.247340 | 0.170793 |
| std | 8.504741 | 0.035950 |
learn_TA48, items_TA48 = train_or_load(model_TA24, dls_TA24, 2, 2e-4, base / "TA_gap_48_v1")| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | -11.938200 | -13.136611 | 0.119232 | 02:38 |
| 1 | 1 | -13.676607 | -14.978353 | 0.111058 | 02:38 |

learn_TA48, items_TA48 = train_or_load(model_TA24, dls_TA24, 2, 1e-4, base / "TA_gap_48_v1_2")| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | -15.563658 | -15.840925 | 0.107003 | 02:44 |
| 1 | 1 | -15.958545 | -16.626610 | 0.103149 | 02:54 |

learn_TA48, items_TA48 = train_or_load(model_TA24, dls_TA24, 2, 1e-4, base / "TA_gap_48_v1_3")| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | -16.665217 | -17.394887 | 0.099992 | 02:40 |
| 1 | 1 | -17.165897 | -17.974066 | 0.097933 | 02:43 |

learn_TA48, items_TA48 = train_or_load(model_TA24, dls_TA24, 2, 1e-4, base / "TA_gap_48_v1_4")| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | -18.111798 | -18.427618 | 0.096457 | 02:49 |
| 1 | 1 | -18.591682 | -19.413278 | 0.092474 | 02:52 |

model_TA24_v2 = KalmanFilterSR.init_local_slope_pca(
n_dim_obs= len(hai.columns),
n_dim_state=len(hai.columns),
n_dim_contr = len(hai_era.columns),
df_pca = hai,
pred_only_gap=True)learn_TA24_v2, items_TA24_v2 = train_or_load(model_TA24, dls_TA24, 3, 1e-3, base / "TA_gap_24_v2_1")TA_models = {}
TA_dls = {}
TA_learn = {}
TA_items = {}
for gap_len in tqdm([6,24,48,7*48]):
display(HTML(f"<h4> TA | Gap len: {gap_len} </h4>"))
TA_dls[gap_len] = imp_dataloader(hai, hai_era, var_sel = 'TA', block_len=gap_len+100, gap_len=gap_len, bs=20, control_lags=[1], n_rep=2, shifts=gen_shifts(50)).cpu()
TA_models[gap_len] = learn_A1v.model.copy()
display(metric_valid(learn_A1v, dls=TA_dls[gap_len].valid))
TA_learn[gap_len], TA_items[gap_len] = train_or_load(TA_models[gap_len], TA_dls[gap_len], 4, 1e-4, base / f"TA_gap_len_{gap_len}_v1.pickle")
plt.show()
TA | Gap len: 6
| loss | rmse | |
|---|---|---|
| mean | -2.093156 | 0.102643 |
| std | 0.403933 | 0.025537 |
| epoch | train_loss | valid_loss | rmse_gap | time |
|---|---|---|---|---|
| 0 | -3.623101 | -4.139968 | 0.074969 | 03:57 |
| 1 | -5.193550 | -5.537112 | 0.060948 | 04:14 |
| 2 | -6.309246 | -6.533323 | 0.052087 | 04:33 |
| 3 | -7.058736 | -7.211129 | 0.046813 | 04:22 |

TA | Gap len: 24
| loss | rmse | |
|---|---|---|
| mean | -0.660455 | 0.200111 |
| std | 2.088128 | 0.031036 |
| epoch | train_loss | valid_loss | rmse_gap | time |
|---|---|---|---|---|
| 0 | -4.815664 | -5.862635 | 0.162549 | 05:23 |
| 1 | -8.257145 | -9.265403 | 0.141516 | 05:27 |
| 2 | -10.721531 | -11.764655 | 0.126987 | 05:17 |
| 3 | -13.121913 | -13.537241 | 0.119559 | 05:31 |

TA | Gap len: 48
| loss | rmse | |
|---|---|---|
| mean | 2.231489 | 0.227216 |
| std | 4.624743 | 0.032391 |
| epoch | train_loss | valid_loss | rmse_gap | time |
|---|---|---|---|---|
| 0 | -5.077721 | -6.730453 | 0.189337 | 06:57 |
| 1 | -10.696784 | -12.647214 | 0.166578 | 06:37 |
| 2 | -15.724886 | -16.361526 | 0.153221 | 06:56 |
| 3 | -18.714646 | -19.920346 | 0.142614 | 06:41 |

TA | Gap len: 336
| loss | rmse | |
|---|---|---|
| mean | 30.665105 | 0.248255 |
| std | 18.911530 | 0.021534 |
| epoch | train_loss | valid_loss | rmse_gap | time |
|---|---|---|---|---|
| 0 | 18.701340 | -1.243045 | 0.226205 | 17:41 |
| 1 | -0.477330 | -23.318493 | 0.212453 | 17:30 |
| 2 | -19.797870 | -42.075069 | 0.200830 | 47:23 |
| 3 | -36.470670 | -55.291288 | 0.192844 | 18:24 |

TA multiple gap len
dls_TA = imp_dataloader(
df = hai,
control = hai_era,
var_sel = 'TA',
block_len=70+336,
gap_len=gen_gap_len(12,336),
bs=20,
control_lags=[1],
shifts=gen_shifts(50),
n_rep=3).cpu()model_TA = learn_A1v.model.copy()metric_valid(learn_A1v, dls= dls_TA.valid)| loss | rmse | |
|---|---|---|
| mean | 18.120834 | 0.240742 |
| std | 19.321141 | 0.032218 |
learn_TA, items_TA = train_or_load(model_TA, dls_TA, 2, 5e-4, base / "TA_gap_12-336_v1,")| epoch | train_loss | valid_loss | rmse_gap | time | |
|---|---|---|---|---|---|
| 0 | 0 | -26.097331 | -37.440739 | 0.175528 | 25:36 |
| 1 | 1 | -50.443927 | -66.472712 | 0.149401 | 26:02 |

learn_TA, items_TA = train_or_load(learn_TA.model, dls_TA, 2, 5e-4, base / "TA_gap_12-336_v2")| epoch | train_loss | valid_loss | rmse_gap | time |
|---|---|---|---|---|
| 0 | -70.902275 | -72.341130 | 0.137253 | 24:20 |
| 1 | -78.379055 | -76.962878 | 0.132967 | 23:55 |

learn_TA, items_TA = train_or_load(learn_TA.model, dls_TA, 2, 5e-4, base / "TA_gap_12-336_v3")| epoch | train_loss | valid_loss | rmse_gap | time |
|---|---|---|---|---|
| 0 | -85.106441 | -82.588917 | 0.128094 | 24:21 |
| 1 | -86.267741 | -83.305525 | 0.129777 | 23:29 |

Detailed metrics
# Report validation metrics (loss / RMSE summary from metric_valid) for
# `learn` at each gap length in `gaps`, rebuilding a dataloader per gap
# length from the global `hai` / `hai_era` frames with a gap in `var` only.
# NOTE(review): `gaps` is a mutable default argument; harmless here since it
# is never mutated, but a tuple default would be the safer idiom.
def metrics_valid_gap_lens(learn, var, gaps = [6,12,24,48,7*48]):
for gap_len in tqdm(gaps):
dls = imp_dataloader(hai, hai_era, var_sel = var, block_len=gap_len+100, gap_len=gap_len, bs=20, control_lags=[1], n_rep=1, shifts=gen_shifts(50)).cpu()
display(HTML(f"<strong> Metrics | gap len: {gap_len} | Var: {var} </strong>"))
display(metric_valid(learn, dls=dls.valid))/home/simone/anaconda3/envs/data-science/lib/python3.10/site-packages/fastai/callback/core.py:69: UserWarning: You are shadowing an attribute (__class__) that exists in the learner. Use `self.learn.__class__` to avoid this
warn(f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this")
metrics_valid_gap_lens(learn_A1v, 'TA')| loss | rmse | |
|---|---|---|
| mean | -2.649933 | 0.095968 |
| std | 0.338268 | 0.018444 |
| loss | rmse | |
|---|---|---|
| mean | -3.243049 | 0.134065 |
| std | 0.710149 | 0.021311 |
| loss | rmse | |
|---|---|---|
| mean | -2.125709 | 0.188222 |
| std | 2.047966 | 0.027886 |
| loss | rmse | |
|---|---|---|
| mean | 0.638268 | 0.217725 |
| std | 6.192707 | 0.031585 |
| loss | rmse | |
|---|---|---|
| mean | 21.469734 | 0.242877 |
| std | 9.829365 | 0.010609 |
metrics_valid_gap_lens(learn_A1v_nc, 'TA')| loss | rmse | |
|---|---|---|
| mean | -3.430541 | 0.055689 |
| std | 0.097171 | 0.009796 |
| loss | rmse | |
|---|---|---|
| mean | -3.790838 | 0.113533 |
| std | 0.833381 | 0.030206 |
| loss | rmse | |
|---|---|---|
| mean | -0.992221 | 0.187918 |
| std | 4.834171 | 0.068321 |
| loss | rmse | |
|---|---|---|
| mean | 5.597832 | 0.235436 |
| std | 12.132120 | 0.075510 |
| loss | rmse | |
|---|---|---|
| mean | 96.691128 | 0.293088 |
| std | 98.837141 | 0.085536 |
metrics_valid_gap_lens(spec_learn['TA'], 'TA')/home/simone/anaconda3/envs/data-science/lib/python3.10/site-packages/fastai/callback/core.py:69: UserWarning: You are shadowing an attribute (__class__) that exists in the learner. Use `self.learn.__class__` to avoid this
warn(f"You are shadowing an attribute ({name}) that exists in the learner. Use `self.learn.{name}` to avoid this")
| loss | rmse | |
|---|---|---|
| mean | -6.332682 | 0.049233 |
| std | 0.226535 | 0.007010 |
| loss | rmse | |
|---|---|---|
| mean | -10.478126 | 0.072882 |
| std | 0.762210 | 0.010581 |
| loss | rmse | |
|---|---|---|
| mean | -15.977039 | 0.103628 |
| std | 4.186457 | 0.024085 |
| loss | rmse | |
|---|---|---|
| mean | -25.748188 | 0.125292 |
| std | 11.229092 | 0.031731 |
| loss | rmse | |
|---|---|---|
| mean | -136.930894 | 0.152361 |
| std | 61.110281 | 0.026692 |
metrics_valid_gap_lens(spec_learn["SWC"], 'SWC', gaps=[6,12,24,30])| loss | rmse | |
|---|---|---|
| mean | -4.637960 | 0.051318 |
| std | 0.299619 | 0.012608 |
| loss | rmse | |
|---|---|---|
| mean | -6.503471 | 0.078820 |
| std | 0.751687 | 0.020815 |
| loss | rmse | |
|---|---|---|
| mean | -6.256985 | 0.114699 |
| std | 1.533914 | 0.030592 |
| loss | rmse | |
|---|---|---|
| mean | -3.957636 | 0.144434 |
| std | 2.228843 | 0.032844 |
metrics_valid_gap_lens(spec_learn3["SWC"], 'SWC', gaps=[6,12,24,30])| loss | rmse | |
|---|---|---|
| mean | -7.409364 | 0.040791 |
| std | 0.602997 | 0.008078 |
| loss | rmse | |
|---|---|---|
| mean | -14.166100 | 0.049997 |
| std | 1.298232 | 0.009736 |
| loss | rmse | |
|---|---|---|
| mean | -26.025482 | 0.058914 |
| std | 2.035570 | 0.009232 |
| loss | rmse | |
|---|---|---|
| mean | -31.526816 | 0.064446 |
| std | 2.601954 | 0.008842 |
metrics_valid_gap_lens(learn_Aa, 'SWC', gaps=[6,12,24,30])| loss | rmse | |
|---|---|---|
| mean | -10.621700 | 0.020779 |
| std | 1.637564 | 0.009605 |
| loss | rmse | |
|---|---|---|
| mean | -21.012826 | 0.024119 |
| std | 1.456359 | 0.007166 |
| loss | rmse | |
|---|---|---|
| mean | -36.451721 | 0.034461 |
| std | 4.163576 | 0.011231 |
| loss | rmse | |
|---|---|---|
| mean | -43.786023 | 0.038519 |
| std | 5.267340 | 0.010467 |
this is pretty weird: the model with gaps in all variables is performing better than the one with only partial gaps …
let’s do some finetuning
SWC | Gap len 30
dls_SWC_30 = imp_dataloader(
df = hai,
control = hai_era,
var_sel = 'SWC',
block_len=120,
gap_len=30,
bs=20,
control_lags=[1],
shifts=gen_shifts(50),
n_rep=1
).cpu()model_SWC_30 = spec_learn['SWC'].copy()if show_metrics: metric_valid(model_SWC_30, dls=model_SWC_30.valid)learn_SWC_30, items_SWC_30 = train_or_load(model_SWC_30, dls_SWC_30, 2, 3e-4, base / "SWC_gap_30_v1")| epoch | train_loss | valid_loss | rmse_gap | time |
|---|---|---|---|---|
| 0 | -10.998339 | -15.672505 | 0.099315 | 02:51 |
| 1 | -17.301334 | -22.166437 | 0.085715 | 02:51 |

learn_SWC_30, items_SWC_30 = train_or_load(model_SWC_30, dls_SWC_30, 2, 1e-3, base / "SWC_gap_30_v2")| epoch | train_loss | valid_loss | rmse_gap | time |
|---|---|---|---|---|
| 0 | -27.642977 | -34.401863 | 0.056773 | 02:56 |
| 1 | -33.920003 | -38.631508 | 0.050980 | 02:57 |

learn_SWC_30, items_SWC_30 = train_or_load(model_SWC_30, dls_SWC_30, 2, 1e-3, base / "SWC_gap_30_v3")| epoch | train_loss | valid_loss | rmse_gap | time |
|---|---|---|---|---|
| 0 | -38.766184 | -42.612718 | 0.041158 | 02:55 |
| 1 | -39.593180 | -42.984206 | 0.041554 | 02:52 |

learn_SWC_30, items_SWC_30 = train_or_load(model_SWC_30, dls_SWC_30, 2, 1e-3, base / "SWC_gap_30_v3")| epoch | train_loss | valid_loss | rmse_gap | time |
|---|---|---|---|---|
| 0 | -38.766184 | -42.612718 | 0.041158 | 02:55 |
| 1 | -39.593180 | -42.984206 | 0.041554 | 02:52 |

show_results(learn_SWC_30, hide_no_gap=True, items=items_SWC_30)[MeteoImpItem(i=1674, shift=30, var_sel=['SWC'], gap_len=30), MeteoImpItem(i=1693, shift=46, var_sel=['SWC'], gap_len=30), MeteoImpItem(i=1570, shift=-83, var_sel=['SWC'], gap_len=30), MeteoImpItem(i=1682, shift=54, var_sel=['SWC'], gap_len=30)]
show_results(learn_Aa, hide_no_gap=True, items=items_SWC_30)[MeteoImpItem(i=1674, shift=30, var_sel=['SWC'], gap_len=30), MeteoImpItem(i=1693, shift=46, var_sel=['SWC'], gap_len=30), MeteoImpItem(i=1570, shift=-83, var_sel=['SWC'], gap_len=30), MeteoImpItem(i=1682, shift=54, var_sel=['SWC'], gap_len=30)]
show_results(learn_A1v, hide_no_gap=True, items=items_SWC_30, dls=dls_SWC_30)[MeteoImpItem(i=1674, shift=30, var_sel=['SWC'], gap_len=30), MeteoImpItem(i=1693, shift=46, var_sel=['SWC'], gap_len=30), MeteoImpItem(i=1570, shift=-83, var_sel=['SWC'], gap_len=30), MeteoImpItem(i=1682, shift=54, var_sel=['SWC'], gap_len=30)]
show_results(spec_learn['SWC'], hide_no_gap=True, items=spec_items['SWC'], dls=spec_dls['SWC'])[MeteoImpItem(i=510, shift=10, var_sel=['SWC'], gap_len=243), MeteoImpItem(i=504, shift=0, var_sel=['SWC'], gap_len=109), MeteoImpItem(i=494, shift=-94, var_sel=['SWC'], gap_len=116), MeteoImpItem(i=481, shift=-17, var_sel=['SWC'], gap_len=148)]
show_results(learn_SWC_30)[MeteoImpItem(i=1892, shift=-50, var_sel=['SWC'], gap_len=30), MeteoImpItem(i=1667, shift=68, var_sel=['SWC'], gap_len=30), MeteoImpItem(i=1761, shift=41, var_sel=['SWC'], gap_len=30)]